import json
import logging
import math
import signal
import time
from argparse import Namespace
from functools import wraps
from itertools import count

import click
import torch
import h5py
import numpy as np

from .bayesopt.bo import BayesianOptimization, shape_step, make_shift
from .bayesopt.gp import GaussianProcess, MeasureMeta, MeasureGoal, KERNELS
from .bayesopt.util import DataSampler

from .utils import grid_search_gamma, interval_schedule
from .cli import QCParams, GPParams, BOParams, ACQUISITION_FNS, OPTIMIZER_SETUPS, TrueSolution, Data, csobj
from .cli import namedtuple_as_dict, final_property, DataLog, option_dict


@click.group()
@click.option('--seed', type=int, default=0xDEADBEEF)
@click.option('--json-log', type=click.Path(writable=True, dir_okay=False))
@click.pass_context
def main(ctx, seed, json_log):
    # Group entry point: configures logging, seeds torch and numpy from
    # ``seed`` and publishes the shared state (``rng``, ``json_log``) on
    # ``ctx.obj`` for the sub-commands.
    # (Kept as comments on purpose: click would render a docstring as help text.)
    logging.basicConfig(
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        level=logging.INFO,
    )
    for logger_name, level in (('qiskit', logging.WARNING), ('', logging.INFO)):
        logging.getLogger(logger_name).setLevel(level)
    torch.manual_seed(seed)

    ctx.ensure_object(Namespace)
    shared = ctx.obj
    shared.rng = np.random.default_rng(seed)
    # ``json_log`` is either a path or None; both are stored unchanged
    shared.json_log = json_log

    def raise_interrupt(signum, frame):
        # turn SIGTERM into KeyboardInterrupt so that ``finally`` clauses of the
        # running command are executed before the process exits
        raise KeyboardInterrupt()

    signal.signal(signal.SIGTERM, raise_interrupt)


def wrap_logger(log_fn):
    '''Create a decorator that reports each result of a function to ``log_fn``.

    The decorated function is called as ``func(model, state)``; afterwards
    ``log_fn(model, state, result)`` is invoked and the result is returned
    unchanged.
    '''
    def decorate(error_fn):
        @wraps(error_fn)
        def logged(model, state):
            result = error_fn(model, state)
            log_fn(model, state, result)
            return result
        return logged
    return decorate


class BayesOptCLI:
    '''Assemble the components of a Bayesian-optimization run from CLI arguments.

    Each component (sampler, kernel, GP model, acquisition optimizer, ...) is
    exposed as a ``final_property`` and built from ``self.args`` on access.
    Iterating over an instance yields the step indices of the optimization
    loop until the budget given by ``args.iter_mode`` / ``args.n_iter`` is
    used up.

    NOTE(review): ``final_property`` comes from ``.cli``; it presumably
    evaluates each property once and caches the result -- confirm there.
    '''

    def __init__(self, args, ctx):
        '''Keep the parsed options and the click context.

        Args:
            args: attribute container holding the command options.
            ctx: click context; ``ctx.obj`` is read for ``rng`` and ``json_log``.
        '''
        self.args = args
        self.start_time = time.time()
        self.ctx = ctx
        # presumably the backing store used by ``final_property`` -- confirm in ``.cli``
        self._dict = {}

    @final_property
    def log(self):
        '''``DataLog`` writing to ``args.output_file`` under the prefix ``data``.'''
        return DataLog(fname=self.args.output_file, prefix='data')

    @final_property
    def sampler(self):
        '''``DataSampler`` configured from the circuit-related options.'''
        return DataSampler(
            self.args.n_qbits,
            self.args.n_layers,
            self.args.j_coupling,
            self.args.h_coupling,
            n_free_angles=self.args.free_angles,
            sector=self.args.sector,
            backend=self.args.backend,
            circuit=self.args.circuit,
            noise_level=self.args.noise_level,
            pbc=self.args.pbc,
            rng=self.ctx.obj.rng,
            cache_fname=self.args.cache,
        )

    @final_property
    def acq_optimizer(self):
        '''Build the acquisition optimizer selected by ``args.acq_params``.

        The acquisition function (``func``, default ``'lcb'``) and the optimizer
        setup (``optim``, default ``'oneshot'``) are looked up by name; each
        registry entry provides a factory and the names of the parameters that
        are read from ``args.acq_params``.
        '''
        params = self.args.acq_params

        # calling the click.Choice type converts/validates the name against the registry keys
        acq_func_name = click.Choice(list(ACQUISITION_FNS))(getattr(params, 'func', 'lcb'))
        acq_func, acq_params = ACQUISITION_FNS[acq_func_name]
        acq_kwargs = {key: getattr(params, key) for key in acq_params}

        acquisition_optimizer_name = click.Choice(list(OPTIMIZER_SETUPS))(getattr(params, 'optim', 'oneshot'))
        optimizer_fn, optimizer_params = OPTIMIZER_SETUPS[acquisition_optimizer_name]
        optimizer_kwargs = {key: getattr(params, key) for key in optimizer_params}

        def candidate_sampler(score_fn=None):
            '''Draw candidate points.

            Without ``score_fn``, ``candidate_samples`` points are sampled. With
            ``score_fn``, ``candidate_samples * candidate_shots`` points are
            sampled and the ``candidate_samples`` highest-scoring ones are kept.
            '''
            if score_fn is None:
                return self.sampler.sample(self.args.candidate_samples, known=False)

            raw_candidates = self.sampler.sample(
                self.args.candidate_samples * self.args.candidate_shots, known=False
            )
            ranking = torch.argsort(score_fn(raw_candidates), descending=True)
            return raw_candidates[ranking[:self.args.candidate_samples]]

        optimizer_kwargs.update({
            'acquisition_fn': acq_func(**acq_kwargs),
            'sampler': candidate_sampler,
            'n_readout': self.args.n_readout,
            'debug': self.args.debug
        })
        return optimizer_fn(**optimizer_kwargs)

    @final_property
    def kernel(self):
        '''Instantiate the kernel named by ``args.kernel``.

        Side effect: ``args.kernel_params`` is replaced by a mutable
        ``Namespace`` copy, and a missing ``sigma_0`` is set to ``1.0``.

        Raises:
            KeyError: if ``args.kernel`` has no entry in the keyword table below
                or in ``KERNELS``.
        '''
        # copy into a Namespace so that sigma_0 can be assigned below
        self.args.kernel_params = Namespace(**self.args.kernel_params._asdict())
        if self.args.kernel_params.sigma_0 is None:
            logging.info(
                f'Value for sigma_0={self.args.kernel_params.sigma_0} is not valid. Set to default sigma_0=1 instead.'
            )
            self.args.kernel_params.sigma_0 = 1.0

        gamma = self.args.kernel_params.gamma
        sigma_0 = self.args.kernel_params.sigma_0

        # constructor keywords per kernel name
        kernel_kwargs = {
            'vqe': {'sigma_0': sigma_0, 'gamma': gamma},
            'sharedvqe': {'sigma_0': sigma_0, 'gamma': gamma, 'eta': 1.0, 'share': self.args.n_qbits},
            'rbf': {'sigma_0': sigma_0, 'gamma': gamma},
            'periodic': {'sigma_0': sigma_0, 'gamma': gamma},
            'matern': {'nu': 2.5},
            'torusrbf': {'sigma_0': sigma_0, 'gamma': gamma},
            'torusmatern': {'nu': 2.5},
        }[self.args.kernel]

        return KERNELS[self.args.kernel](**kernel_kwargs)

    @final_property
    def train_data(self):
        '''Initial training data for the GP.

        If ``args.train_data`` names an HDF5 file, two layouts are accepted:

        * ``x_train`` / ``y_train`` datasets, which are loaded completely;
        * ``data`` / ``params`` entries (the layout written by ``store``): the
          single entry ``args.train_data_index`` of ``data/angles``,
          ``data/energy`` and ``data/n_qc_readout`` is loaded, and the source
          file's ``params`` are copied to ``base/params`` in ``args.output_file``.

        Otherwise ``args.train_samples`` points are obtained from the sampler.

        Raises:
            RuntimeError: if the file matches neither layout.
        '''
        train_file = getattr(self.args, 'train_data', None)
        if train_file is not None:
            with h5py.File(train_file, 'r') as fd:
                if all(key in fd for key in ('x_train', 'y_train')):
                    return Data(
                        torch.from_numpy(fd['x_train'][()]),
                        torch.from_numpy(fd['y_train'][()]),
                    )
                elif all(key in fd for key in ('data', 'params')):
                    elems = {
                        'base/params': fd['params'][()],
                    }
                    # keep the parameters of the source run next to the new results
                    with h5py.File(self.args.output_file, 'a') as fd2:
                        for key, val in elems.items():
                            fd2[key] = val

                    # length-one slice: keeps the leading (sample) axis
                    index = slice(self.args.train_data_index, self.args.train_data_index + 1)
                    return Data(
                        torch.from_numpy(fd['data/angles'][index]),
                        torch.from_numpy(fd['data/energy'][index]),
                        readout=torch.from_numpy(fd['data/n_qc_readout'][index]),
                    )
                else:
                    raise RuntimeError('train-data is incompatible')
        return Data(*self.sampler.cached_sample(
            self.args.train_samples, key='train', force_compute=self.args.train_data_mode == 'compute'
        ))

    @final_property
    def true_solution(self):
        '''``TrueSolution`` built from the sampler's exact diagonalization.'''
        return TrueSolution(*self.sampler.exact_diag())

    @final_property
    def model(self):
        '''Gaussian process initialized with the training data.

        Side effect: if ``args.y_var_default_estimates`` is set,
        ``args.y_var_default`` is overwritten with a variance estimated by the
        sampler before the model is built.
        '''
        if self.args.y_var_default_estimates is not None:
            logging.info('Estimating default observation noise (y_var_default)...')
            self.args.y_var_default = self.sampler.estimate_variance(
                self.args.y_var_default_estimates, self.args.n_readout
            )
            logging.info(f'Estimated {self.args.y_var_default:0.2e} for default observation noise.')

        # one meta entry per training point; readout counts are only known when
        # the training data was loaded together with them
        readouts = self.train_data.readout
        if readouts is None:
            readouts = [None] * len(self.train_data.x)
        meta = [
            MeasureMeta(step=0, goal=MeasureGoal.INIT, readout=None if readout is None else readout.item())
            for readout in readouts
        ]

        return GaussianProcess(
            self.train_data.x,
            self.train_data.y,
            kernel=self.kernel,
            y_var_default=self.args.y_var_default,
            inducer=self.args.inducer,
            meta=meta,
        )

    @final_property
    def bayes_opt(self):
        '''``BayesianOptimization`` instance wired up with the logging hooks.

        The optimizer's ``error_fn`` is a dummy returning ``0.0`` that is wrapped
        with ``observe_fn``; calling it computes the oracle observables and
        writes the stdout, JSON and HDF5 logs.

        Raises:
            RuntimeError: if ``args.y_var_default`` is ``0.0`` while a
                stabilization interval is configured.
        '''

        def json_default(obj):
            '''Convert objects offering ``tolist`` (tensors, arrays) for the JSON encoder.'''
            if hasattr(obj, 'tolist'):
                return obj.tolist()
            raise TypeError(f'Object of type {type(obj).__name__} is not JSON serializable')

        def append_json_log(state):
            '''Write state to json file (if option is enabled).'''
            if self.ctx.obj.json_log is None:
                return
            stlog = state.get('log', {})

            # indirect observables (only available depending on state)
            observables = {
                'n_qc_eval': state.get('n_qc_eval', 0),
                'n_qc_readout': state.get('n_qc_readout', 0),
                'y_best': float(state.get('y_best', 0.)),
                'y_start': float(stlog.get('y_est', 0.)),
                'y_true': stlog.get('y_true'),
                'fidelity': stlog.get('fidelity'),
                'corethresh': state.get('corethresh'),
                'gamma': stlog.get('gamma'),
                'k_last': stlog.get('k_last'),
            }

            # Serialize before touching the file so that a failure cannot leave a
            # partial line behind; ``default`` covers values that may be tensors
            # (e.g. 'gamma', 'corethresh'), which the encoder rejects otherwise.
            line = json.dumps(observables, default=json_default)

            # store observable log in json file (one JSON document per line)
            with open(self.ctx.obj.json_log, 'a') as fd:
                fd.write(line + '\n')

        def append_hdf5_log(state):
            '''Write extended state to h5 file.'''
            stlog = state.get('log', {})

            # storables are appended and stored to the h5 file
            self.log.update({
                'true_energy': stlog.get('y_true'),  # (oracle) true energy at pivot point
                'y_std_meas': stlog.get('y_std_meas'),  # measured standard deviation at pivot point
                'overlap': stlog.get('fidelity'),  # (oracle) overlap with ground state
                # optimizer prediction at pivot point (minimum of GP-mean of energy / minimum of NFT curve)
                'energy': stlog.get('y_est'),
                # std of the optimizer prediction at pivot point
                'energy_uncertainty': stlog.get('y_est_std'),
                'x_measured': state.get('x_measured'),  # angles to be measured (shifted points wrt the pivot point)
                'y_measured': state.get('y_measured'),  # measured energy at x_measured
                'angles': state.get('x_start'),  # angles of pivot point
                'gamma': stlog.get('gamma'),  # current gamma value of kernel
                'n_qc_eval': state.get('n_qc_eval'),  # cumulative number of calls to energy evaluation function
                'n_qc_readout': state.get('n_qc_readout'),  # cumulative number of readouts (shots) per hamiltonian
                'k_last': stlog.get('k_last'),  # flat index of current direction (where last shift was measured)
                'corethresh': state.get('corethresh'),  # threshold for CoRe (kappa)
                # 'l_r': state.get('scheduler').get_last_lr(),  # learning rate
                'runtime': time.time() - self.start_time,  # seconds since this object was created
                # gp variance
                'gp_prior_var_max': stlog.get('gp_prior_var_max'),  # max gp-variance along k_last before measuring
                'gp_post_var_max': stlog.get('gp_post_var_max'),  # max gp-variance along k_last after measuring
                'pivot_prior_var': stlog.get('pivot_prior_var'),  # gp-variance at pivot point before measuring
                'pivot_post_var': stlog.get('pivot_post_var'),  # gp-variance at pivot point after measuring
                # true grad (all filled by observe_fn when the 'grad' debug flag is set)
                'x_psr_grad_1': stlog.get('x_psr_grad_1'),  # mean abs parameter-shift gradient (second arg 1024)
                'x_psr_grad_1_noisless': stlog.get('x_psr_grad_1_noisless'),  # same with second arg 0
                'x_psr_grad_1_bayes': stlog.get('x_psr_grad_1_bayes'),  # mean abs of the GP posterior gradient mean
                'x_grad_dir': stlog.get('x_grad_dir'),  # autograd gradient of exact energy along current dir
                'x_grad_ldir': stlog.get('x_grad_ldir'),  # autograd gradient of exact energy along last dir
                'x_grad_1': stlog.get('x_grad_1'),  # mean absolute value of the exact-energy gradient
                'x_grad_2': stlog.get('x_grad_2'),  # root-mean-square of the exact-energy gradient
                # autograd grads
                'pivot_prior_grad_1': stlog.get('pivot_prior_grad_1'),  # l1-gradnorm of gp-mean wrt. pivot point before
                'pivot_post_grad_1': stlog.get('pivot_post_grad_1'),  # l1-gradnorm of gp-mean wrt. pivot point after
                'pivot_prior_grad_2': stlog.get('pivot_prior_grad_2'),  # l2-gradnorm of gp-mean wrt. pivot point before
                'pivot_post_grad_2': stlog.get('pivot_post_grad_2'),  # l2-gradnorm of gp-mean wrt. pivot point after
                'pivot_prior_grad_dir': (
                    stlog.get('pivot_prior_grad_dir')  # grad of direction of gp-mean wrt. pivot point before
                ),
                'pivot_post_grad_dir': (
                    stlog.get('pivot_post_grad_dir')  # grad of direction of gp-mean wrt. pivot point after
                ),
                'pivot_prior_grad_ldir': (
                    stlog.get('pivot_prior_grad_ldir')  # grad of last direction of gp-mean wrt. pivot point before
                ),
                # grad-kernel grads
                'pivot_prior_kgrad_1': stlog.get('pivot_prior_kgrad_1'),  # l1-gradnorm of gp-mean wrt. pivot point before
                'pivot_post_kgrad_1': stlog.get('pivot_post_kgrad_1'),  # l1-gradnorm of gp-mean wrt. pivot point after
                'pivot_prior_kgrad_2': stlog.get('pivot_prior_kgrad_2'),  # l2-gradnorm of gp-mean wrt. pivot point before
                'pivot_post_kgrad_2': stlog.get('pivot_post_kgrad_2'),  # l2-gradnorm of gp-mean wrt. pivot point after
                'pivot_prior_kgrad_dir': (
                    stlog.get('pivot_prior_kgrad_dir')  # grad of direction of gp-mean wrt. pivot point before
                ),
                'pivot_post_kgrad_dir': (
                    stlog.get('pivot_post_kgrad_dir')  # grad of direction of gp-mean wrt. pivot point after
                ),
                'pivot_prior_kgrad_ldir': (
                    stlog.get('pivot_prior_kgrad_ldir')  # grad of last direction of gp-mean wrt. pivot point before
                ),
                'pivot_prior_kgradvar_2': (
                    stlog.get('pivot_prior_kgradvar_2')  # l2-gradnorm of gp-var wrt. pivot point before
                ),
                'pivot_post_kgradvar_2': (
                    stlog.get('pivot_post_kgradvar_2')  # l2-gradnorm of gp-var wrt. pivot point after
                ),
                # align
                'k_align_mean': stlog.get('k_align_mean'),  # mean alignment
                'k_align_max': stlog.get('k_align_max'),  # max alignment
            })
            # extend is required for storables that add more than one element per step
            self.log.update_extend({
                'x_train': stlog.get('x_meas'),  # all measured angles
                'y_train': stlog.get('y_meas'),  # measured energy values of all measured angles
                'y_train_var': stlog.get('y_var'),  # observation noise variance of all measured angles
                'step_train': [state['step']] * len(stlog['x_meas']) if 'x_meas' in stlog else None,  # step of measure
                'gp_prior_var_train': stlog.get('gp_prior_var'),  # gp-variance of measured points before measuring
                'gp_post_var_train': stlog.get('gp_post_var'),  # gp-variance of measured points after measuring
            })

        def append_stdout_log(state):
            '''Log a one-line summary of the step; missing values are omitted.'''
            def ifhas(template, obj, key=None):
                '''Format ``obj`` (or ``obj.get(key)``) with ``template``; empty string if it is None.'''
                if key is not None:
                    obj = obj.get(key)
                if isinstance(obj, torch.Tensor):
                    obj = obj.item()
                return '' if obj is None else template.format(obj)

            stlog = state.get('log', {})

            outlog = [
                f'Step {state["step"]: 4d}',
                ifhas('eval {: 4d}', state, 'n_qc_eval'),
                ifhas('readout {: 9d}', state, 'n_qc_readout'),
                ifhas('est-energy: {:.3f}', stlog, 'y_est'),
                ifhas('true-energy: {:.3f}', stlog, 'y_true'),
                ifhas('fidelity: {:.3f}', stlog, 'fidelity'),
                ifhas('corethresh {:.3f}', state, 'corethresh'),
            ]

            # extra values requested with --log
            for key in self.args.log:
                outlog.append(ifhas('%s: {:.3f}' % key, stlog, key))

            outlog = [val for val in outlog if val]
            logging.info('; '.join(outlog))

        def observe_fn(model, state, error):
            '''Add oracle observables to ``state['log']`` and write all logs.'''
            # NOTE(review): when ``state`` has no 'log' entry, the values below are
            # written to a temporary dict and never reach the loggers -- confirm
            # that the optimizer always provides ``state['log']``.
            stlog = state.get('log', {})
            if 'x_start' in state:
                # add a leading axis; gradients wrt. the pivot point are needed below
                x_input = state['x_start'][None].requires_grad_()
                y_true = self.sampler.exact_energy(x_input)
                stlog['fidelity'] = self.sampler.exact_overlap(state['x_start'][None]).item()
                stlog['y_true'] = y_true.detach().item()
                stlog['y_est'] = state['y_start'].item()

                if 'y_start_std' in state:
                    stlog['y_est_std'] = state['y_start_std'].item()
                if self.args.debug.has('grad'):
                    x_grad = torch.autograd.grad(y_true, x_input)[0].squeeze(0)
                    x_numpy = x_input.detach().numpy()
                    psr_gradient = self.sampler.backend.parameter_shift_gradient
                    x_psr_grad = torch.tensor(psr_gradient(x_numpy, 1024))
                    x_psr_grad_noisless = torch.tensor(psr_gradient(x_numpy, 0))
                    x_psr_grad_bayes = model.posterior_grad(x_input, diag=True).mean.squeeze(-1)
                    if 'k_best' in state:
                        stlog['x_grad_dir'] = x_grad[state['k_best']]
                        stlog['x_grad_ldir'] = x_grad[shape_step(state['k_best'], state['x_start'].shape, -1)]

                    stlog['x_psr_grad_1'] = x_psr_grad.abs().mean().detach().numpy()
                    stlog['x_psr_grad_1_noisless'] = x_psr_grad_noisless.abs().mean().detach().numpy()
                    stlog['x_psr_grad_1_bayes'] = x_psr_grad_bayes.abs().mean().detach().numpy()
                    stlog['x_grad_1'] = x_grad.abs().mean()
                    stlog['x_grad_2'] = (x_grad ** 2).mean() ** .5

                if self.args.debug.has('align') and 'k_best' in state:
                    # compute alignment: deviation of the GP mean from the exact
                    # energy along direction k_best, in units of the GP std
                    lins = torch.linspace(0, math.tau, 101)[:-1]
                    x_line = make_shift(state['x_start'], lins, state['k_best'])
                    t_line, _ = self.sampler.exact_line(state['x_start'], lins, state['k_best'], force_compute=True)
                    d_line = model.posterior(x_line, diag=True)
                    align = (t_line - d_line.mean).abs() / d_line.std
                    stlog['k_align_mean'] = align.mean()
                    stlog['k_align_max'] = align.max()

            if 'k_best' in state:
                # flatten the multi-index of the direction
                stlog['k_last'] = int(np.ravel_multi_index(state['k_best'], state['x_start'].shape))
            if 'gamma' in model.kernel.param_dict():
                stlog['gamma'] = model.kernel.param_dict().get('gamma')

            append_stdout_log(state)
            append_json_log(state)
            append_hdf5_log(state)

        @wrap_logger(observe_fn)
        def rms_error_fn(model, state):
            # no error is computed; the function only serves as the logging hook
            return 0.0

        if self.args.y_var_default == 0.0 and self.args.acq_params.stabilize_interval:
            raise RuntimeError(
                'Exact GPs cannot be used with a stabilization interval! The resulting Gram matrix will become non-PD!'
            )

        return BayesianOptimization(
            model=self.model,
            optimizer=self.acq_optimizer,
            sampler=self.sampler,
            error_fn=rms_error_fn,
            cheat=self.args.cheat,
            var=self.args.var_mode,
        )

    @final_property
    def interval(self):
        '''Schedule deciding at which steps hyper-parameters are optimized.'''
        return interval_schedule(self.args.hyperopt.interval)

    @final_property
    def lossfn(self):
        '''Model method selected by ``args.hyperopt.loss`` (``'loo'`` or ``'mll'``).'''
        return {
            'loo': self.model.loocv_mll_closed,
            'mll': self.model.log_likelihood,
        }[self.args.hyperopt.loss]

    @final_property
    def kernel_optim(self):
        '''Adam optimizer over the kernel parameters.

        Raises:
            RuntimeError: if ``args.hyperopt.optim`` is not ``'adam'``.
        '''
        if self.args.hyperopt.optim != 'adam':
            raise RuntimeError(f'Cannot use kernel_optim with {self.args.hyperopt.optim}!')
        params = list(self.kernel.parameters())
        for param in params:
            param.requires_grad = True
        if not params:
            logging.warning(f'{self.kernel} has no hyper parameters to optimize!')
        return torch.optim.Adam(params, lr=self.args.hyperopt.lr)

    @final_property
    def n_iter(self):
        '''An estimate for the number of iterations (only exact for iter-mode 'step').'''
        if self.args.iter_mode == 'qc':
            return (self.args.n_iter - self.args.train_samples) // 2
        elif self.args.iter_mode == 'readout':
            remaining = self.args.n_iter - self.args.train_samples * self.args.n_readout
            # guard against a zero divisor for n_readout < 2
            return remaining // max(1, self.args.n_readout // 2)

        return self.args.n_iter

    def __len__(self):
        '''Estimated number of steps, see ``n_iter``.'''
        return self.n_iter

    def __iter__(self):
        '''Yield step indices ``0, 1, ...`` until the budget is exhausted.

        * iter-mode ``'qc'``: stop once ``n_qc_eval`` reaches ``args.n_iter``;
        * iter-mode ``'readout'``: stop once ``n_qc_readout`` reaches ``args.n_iter``;
        * every mode except ``'qc'``: stop after at most ``args.n_iter`` steps.
        '''
        mode = self.args.iter_mode
        limit = self.args.n_iter
        for i in count():
            # the optimizer state is only consulted in the modes that need it
            if mode == 'qc' and self.bayes_opt.optim_state['n_qc_eval'] >= limit:
                return
            if mode == 'readout' and self.bayes_opt.optim_state['n_qc_readout'] >= limit:
                return
            # ``>=`` (not ``>``): exactly ``n_iter`` steps, consistent with ``__len__``
            if mode != 'qc' and i >= limit:
                return
            yield i

    def store(self, step, initial=False):
        '''Flush the data log and, if configured, checkpoint the GP state.

        Args:
            step: step index used in the checkpoint key ``state/<step>/gp/...``.
            initial: if true, first (re)create the output file and the
                checkpoint file (truncating existing ones) with the exact
                energies and the run parameters, and queue the initial
                training data for the log.
        '''
        output_file = self.args.output_file
        checkpoint_file = self.args.checkpoint_file

        if initial:
            initial_dict = {
                'data/true_e0': self.true_solution.e0,
                'data/true_e1': self.true_solution.e1,
                'params': json.dumps(namedtuple_as_dict(self.ctx.params)),
            }
            init_files = [output_file]
            if checkpoint_file is not None:
                init_files.append(checkpoint_file)
            for init_file in init_files:
                # mode 'w' truncates an existing file
                with h5py.File(init_file, 'w') as fd:
                    for key, val in initial_dict.items():
                        fd[key] = val
                logging.info(f'Initialized \'{init_file}\'.')

            self.log.update_extend({
                'x_train': list(self.train_data.x),
                'y_train': list(self.train_data.y),
                'y_train_var': [self.model.y_var_default] * len(self.train_data.x),
            })

        self.log.flush()
        logging.info(f'Flushed results to \'{output_file}\'.')
        logging.info(f'Size of GP: \'{len(self.model)}\'.')

        if checkpoint_file is not None:
            state_dict = self.model.state_dict()
            with h5py.File(checkpoint_file, 'a') as fd:
                for key, val in state_dict.items():
                    h5key = f'state/{step:d}/gp/{key}'
                    # replace an entry stored earlier for the same step
                    if h5key in fd:
                        del fd[h5key]
                    fd[h5key] = val
            logging.info(f'Saved state to \'{checkpoint_file}\'.')

    def hyperopt(self, step):
        '''Optimize the kernel hyper-parameters if the schedule selects ``step``.

        With ``args.hyperopt.optim == 'grid'`` a grid search over gamma is run;
        its upper bound is jittered with Gaussian noise. With ``'adam'``,
        ``args.hyperopt.steps`` gradient steps are taken on the selected loss
        and the model is re-initialized after each one. Any other value is a
        no-op.
        '''
        if self.args.hyperopt.optim == 'grid':
            if self.interval(step):
                max_gamma = self.args.hyperopt.max_gamma
                wiggle = self.ctx.obj.rng.normal(0, (max_gamma - 1) / self.args.hyperopt.steps)
                grid_search_gamma(
                    self.model,
                    min_gamma=np.sqrt(2.),
                    max_gamma=max_gamma + wiggle,
                    num=self.args.hyperopt.steps,
                    loss=self.args.hyperopt.loss
                )
        elif self.args.hyperopt.optim == 'adam':
            if self.interval(step):
                # ``_``: do not shadow the ``step`` argument
                for _ in range(self.args.hyperopt.steps):
                    self.kernel_optim.zero_grad()
                    loss = self.lossfn()
                    loss.backward()
                    self.kernel_optim.step()
                    self.model.reinit()


@main.command('train')
@click.argument('output_file', type=click.Path(writable=True))
@click.option('--train-data', type=click.Path(exists=True))
@click.option('--train-data-index', type=int, default=0)
@click.option('--flush-interval', type=int, default=50)
@click.option('--checkpoint', 'checkpoint_file', type=click.Path(writable=True))
@click.option('--meta', type=option_dict, help='Meta-info to store in params.')
@click.option('--log', type=csobj(str), default='', help='Extra values to log on stdout')
@QCParams.options()
@GPParams.options()
@BOParams.options()
@click.pass_context
def train(ctx, **kwargs):
    # Runs the optimization loop: per step the kernel hyper-parameters are
    # (optionally) optimized, one Bayesian-optimization step is taken and the
    # results are flushed every ``--flush-interval`` steps. Results are also
    # stored when the loop is left through an exception or interrupt.
    # (Kept as comments on purpose: click would render a docstring as help text.)
    args = Namespace(**kwargs)
    ns = BayesOptCLI(args, ctx)

    # initialize storage
    ns.store(0, initial=True)

    # Pre-bind the loop variable: if the iterator raises or is exhausted before
    # the first step, the ``finally`` clause would otherwise fail with a
    # NameError that hides the actual error.
    bayes_step = -1
    completed = False
    try:
        for bayes_step in ns:
            ns.hyperopt(bayes_step)
            ns.bayes_opt.step(bayes_step)
            if args.flush_interval and ((bayes_step + 1) % args.flush_interval == 0):
                ns.store(bayes_step + 1)
        completed = True
    finally:
        ns.store(bayes_step + 1)
        # only report success if the loop actually ran to completion
        if completed:
            logging.info('Bayesian Optimization ended successfully')
        else:
            logging.warning('Bayesian Optimization stopped before completion; partial results were stored')


@main.command('make-train')
@click.argument('output')
@QCParams.options()
@GPParams.options()
@BOParams.options()
@click.pass_context
def make_train(ctx, **kwargs):
    # Writes the training data (``x_train`` / ``y_train``) to the HDF5 file
    # ``output``; the data is requested with train_data_mode 'cache'.
    # (Kept as comments on purpose: click would render a docstring as help text.)
    args = Namespace(**{**kwargs, 'train_data_mode': 'cache'})
    ns = BayesOptCLI(args, ctx)

    with h5py.File(args.output, 'w') as out:
        for name in ('x', 'y'):
            out[f'{name}_train'] = getattr(ns.train_data, name)

    if args.cache is None:
        logging.info('Nothing initialized!')
        return

    # accessing true_solution also triggers the exact diagonalization
    if ns.train_data is not None and ns.true_solution is not None:
        logging.info(f'Initialized {args.cache}.')


# Invoke the click command group when the module is executed as a script.
if __name__ == '__main__':
    main()
